{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "25390674",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PreTrainedTokenizerFast(name_or_path='t5-small', vocab_size=32100, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>', '<extra_id_43>', '<extra_id_44>', '<extra_id_45>', '<extra_id_46>', '<extra_id_47>', '<extra_id_48>', '<extra_id_49>', '<extra_id_50>', '<extra_id_51>', '<extra_id_52>', '<extra_id_53>', '<extra_id_54>', '<extra_id_55>', '<extra_id_56>', '<extra_id_57>', '<extra_id_58>', '<extra_id_59>', '<extra_id_60>', '<extra_id_61>', '<extra_id_62>', '<extra_id_63>', '<extra_id_64>', '<extra_id_65>', '<extra_id_66>', '<extra_id_67>', '<extra_id_68>', '<extra_id_69>', '<extra_id_70>', '<extra_id_71>', '<extra_id_72>', '<extra_id_73>', '<extra_id_74>', '<extra_id_75>', '<extra_id_76>', '<extra_id_77>', '<extra_id_78>', '<extra_id_79>', '<extra_id_80>', '<extra_id_81>', '<extra_id_82>', '<extra_id_83>', '<extra_id_84>', '<extra_id_85>', '<extra_id_86>', '<extra_id_87>', '<extra_id_88>', '<extra_id_89>', '<extra_id_90>', '<extra_id_91>', '<extra_id_92>', '<extra_id_93>', '<extra_id_94>', '<extra_id_95>', '<extra_id_96>', '<extra_id_97>', '<extra_id_98>', '<extra_id_99>']})\n",
      "{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}\n",
      "{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "#加载分词器\n",
    "tokenizer = AutoTokenizer.from_pretrained('t5-small')\n",
    "\n",
    "print(tokenizer)\n",
    "\n",
    "#编码试算\n",
    "print(\n",
    "    tokenizer.batch_encode_plus(\n",
    "        ['Hello, this one sentence!', 'This is another sentence.']))\n",
    "\n",
    "#label的编码方式,但是试验结果是和input的编码方式一样,没有区别\n",
    "with tokenizer.as_target_tokenizer():\n",
    "    print(\n",
    "        tokenizer.batch_encode_plus(\n",
    "            ['Hello, this one sentence!', 'This is another sentence.']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "aa11505b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached shuffled indices for dataset at datas/xsum/train/cache-5a721ded3b377201.arrow\n",
      "Loading cached shuffled indices for dataset at datas/xsum/validation/cache-a35139f4dcb69b4b.arrow\n",
      "Loading cached shuffled indices for dataset at datas/xsum/test/cache-522a74c9bc704175.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'document': \"Clay, who has agreed a two-year deal, made 39 appearances for Scottish Premiership club Motherwell last season after joining them in June 2016.\\nThe 25-year-old had spent the two previous seasons with Grimsby, playing 74 National League games.\\nClay is Leyton Orient's ninth signing since being relegated from League Two last season.\", 'summary': 'National League side Leyton Orient have signed Motherwell midfielder Craig Clay on a free transfer.', 'id': '40635923'}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['document', 'summary', 'id'],\n",
       "        num_rows: 20000\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['document', 'summary', 'id'],\n",
       "        num_rows: 100\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['document', 'summary', 'id'],\n",
       "        num_rows: 100\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset, load_from_disk\n",
    "\n",
    "#加载数据集\n",
    "#dataset = load_dataset('xsum')\n",
    "dataset = load_from_disk('datas/xsum')\n",
    "\n",
    "#采样,数据量太大了跑不动\n",
    "dataset['train'] = dataset['train'].shuffle(1).select(range(20000))\n",
    "dataset['validation'] = dataset['validation'].shuffle(1).select(range(100))\n",
    "dataset['test'] = dataset['test'].shuffle(1).select(range(100))\n",
    "\n",
    "print(dataset['train'][0])\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4abdde01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/train/cache-495b96586a2de104.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/train/cache-b84649ce42fe08dd.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/train/cache-2f914e8128c67ce2.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/train/cache-674ddaa89e123e9b.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/validation/cache-efcb6687daf3d8ba.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/validation/cache-d6a0fcda0ff0735f.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/validation/cache-0f76d8367084f805.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/validation/cache-e05ad845b8e697e2.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/test/cache-d807c068c9a4fabe.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/test/cache-4c8896e0ad2e5251.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/test/cache-b2007e056428bf6c.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at datas/xsum/test/cache-16b2d4bdf57a8a24.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': [21603, 10, 20988, 6, 113, 65, 4686, 3, 9, 192, 18, 1201, 1154, 6, 263, 6352, 3179, 7, 21, 12580, 6552, 2009, 1886, 8007, 2091, 336, 774, 227, 6109, 135, 16, 1515, 4619, 37, 944, 18, 1201, 18, 1490, 141, 1869, 8, 192, 1767, 9385, 28, 23427, 7, 969, 6, 1556, 3, 4581, 868, 3815, 1031, 5, 20988, 19, 312, 21220, 3, 16495, 31, 7, 24651, 8097, 437, 271, 3, 60, 8791, 26, 45, 3815, 2759, 336, 774, 5, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [868, 3815, 596, 312, 21220, 3, 16495, 43, 3814, 8007, 2091, 2076, 1846, 49, 12870, 20988, 30, 3, 9, 339, 2025, 5, 1]}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 20000\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['input_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 100\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 100\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#数据预处理函数,分词\n",
    "def f(examples):\n",
    "    #编码input\n",
    "    data = tokenizer.batch_encode_plus(\n",
    "        #在输入的前面加summarize前缀,这个是h5模型识别任务类型的标识\n",
    "        ['summarize: ' + i for i in examples['document']],\n",
    "        max_length=1024,\n",
    "        truncation=True,\n",
    "    )\n",
    "\n",
    "    #编码label\n",
    "    with tokenizer.as_target_tokenizer():\n",
    "        data['labels'] = tokenizer.batch_encode_plus(\n",
    "            examples['summary'],\n",
    "            max_length=128,\n",
    "            truncation=True,\n",
    "        )['input_ids']\n",
    "\n",
    "    return data\n",
    "\n",
    "\n",
    "dataset = dataset.map(function=f,\n",
    "                      batched=True,\n",
    "                      batch_size=1000,\n",
    "                      num_proc=4,\n",
    "                      remove_columns=['document', 'summary', 'id'])\n",
    "\n",
    "print(dataset['train'][0])\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e5afcae3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[21603,    10,    37,  3719,    13],\n",
       "        [21603,    10,  7086,  8408,   563]]), 'attention_mask': tensor([[1, 1, 1, 1, 1],\n",
       "        [1, 1, 1, 1, 1]]), 'labels': tensor([[10455,   120,    80,  -100],\n",
       "        [  301,    53,  4074,  1669]]), 'decoder_input_ids': tensor([[    0, 10455,   120,    80],\n",
       "        [    0,   301,    53,  4074]])}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#这个函数和下面这个工具类等价,但我也是模仿实现的,不确定有没有出入\n",
    "#from transformers import DataCollatorForSeq2Seq\n",
    "#DataCollatorForSeq2Seq(tokenizer, model=model)\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "#数据批处理函数\n",
    "def collate_fn(data):\n",
    "    #求最长的label\n",
    "    max_length = max([len(i['labels']) for i in data])\n",
    "\n",
    "    #把所有的label都补pad到最长\n",
    "    for i in data:\n",
    "        pads = [-100] * (max_length - len(i['labels']))\n",
    "        i['labels'] = i['labels'] + pads\n",
    "\n",
    "    #把多个数据整合成一个tensor\n",
    "    data = tokenizer.pad(\n",
    "        encoded_inputs=data,\n",
    "        padding=True,\n",
    "        max_length=None,\n",
    "        pad_to_multiple_of=None,\n",
    "        return_tensors='pt',\n",
    "    )\n",
    "\n",
    "    #定义decoder_input_ids\n",
    "    data['decoder_input_ids'] = torch.zeros_like(data['labels'],\n",
    "                                                 dtype=torch.long)\n",
    "    data['decoder_input_ids'][:, 1:] = data['labels'][:, :-1]\n",
    "    data['decoder_input_ids'][data['decoder_input_ids'] == -100] = 0\n",
    "\n",
    "    return data\n",
    "\n",
    "\n",
    "data = [{\n",
    "    'input_ids': [21603, 10, 37, 3719, 13],\n",
    "    'attention_mask': [1, 1, 1, 1, 1],\n",
    "    'labels': [10455, 120, 80]\n",
    "}, {\n",
    "    'input_ids': [21603, 10, 7086, 8408, 563],\n",
    "    'attention_mask': [1, 1, 1, 1, 1],\n",
    "    'labels': [301, 53, 4074, 1669]\n",
    "}]\n",
    "\n",
    "collate_fn(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bdacd9d0",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids torch.Size([8, 1024])\n",
      "attention_mask torch.Size([8, 1024])\n",
      "labels torch.Size([8, 35])\n",
      "decoder_input_ids torch.Size([8, 35])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "2500"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#数据加载器\n",
    "loader = torch.utils.data.DataLoader(\n",
    "    dataset=dataset['train'],\n",
    "    batch_size=8,\n",
    "    collate_fn=collate_fn,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    ")\n",
    "\n",
    "for i, data in enumerate(loader):\n",
    "    break\n",
    "\n",
    "for k, v in data.items():\n",
    "    print(k, v.shape)\n",
    "\n",
    "len(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "098c5e4a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7695.616\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(3.9883, grad_fn=<NllLossBackward0>), torch.Size([8, 35, 32128]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoModelForSeq2SeqLM, T5Model\n",
    "\n",
    "#加载模型\n",
    "#model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')\n",
    "\n",
    "\n",
    "#定义下游任务模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.pretrained = T5Model.from_pretrained('t5-small')\n",
    "\n",
    "        #本来应该写tokenizer.vocab_size=32100就可以了。\n",
    "        #但是预训练模型里的参数是32128，所以为了方便就直接使用这个尺寸了\n",
    "        self.fc = torch.nn.Linear(512, 32128, bias=False)\n",
    "\n",
    "        #加载预训练模型的参数\n",
    "        parameters = AutoModelForSeq2SeqLM.from_pretrained('t5-small')\n",
    "        self.fc.load_state_dict(parameters.lm_head.state_dict())\n",
    "\n",
    "        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, labels, decoder_input_ids):\n",
    "\n",
    "        logits = self.pretrained.encoder(input_ids=input_ids,\n",
    "                                         attention_mask=attention_mask)\n",
    "\n",
    "        logits = logits.last_hidden_state\n",
    "\n",
    "        logits = self.pretrained.decoder(\n",
    "            input_ids=decoder_input_ids,\n",
    "            encoder_hidden_states=logits,\n",
    "            encoder_attention_mask=attention_mask,\n",
    "        )\n",
    "\n",
    "        logits = logits.last_hidden_state\n",
    "\n",
    "        logits = logits * (512**-0.5)\n",
    "\n",
    "        logits = self.fc(logits)\n",
    "\n",
    "        loss = self.criterion(logits.reshape(-1, 32128),\n",
    "                              labels.reshape(-1))\n",
    "\n",
    "        return {'loss': loss, 'logits': logits}\n",
    "\n",
    "\n",
    "model = Model()\n",
    "\n",
    "#统计参数量\n",
    "print(sum(i.numel() for i in model.parameters()) / 10000)\n",
    "\n",
    "#模型试算\n",
    "out = model(**data)\n",
    "\n",
    "out['loss'], out['logits'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "20d9a12b",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pred= ('the memorial will be placed near the of thea man who fell in  into the river.  the  the the the the the the the the the the the the the the',)\n",
      "label= <pad> A monument will be built in memory of a student who died after falling into the River Avon.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "pred= (' trialsting child citizensex trafficender  in to have for agentscover agents  sex  a boy    picturesnography images  the   ',)\n",
      "label= <pad> A convicted British sex offender caught trying to pay US undercover officers for sex with a boy has admitted taking pornographic images into the country.\n",
      "\n",
      "pred= (' say investigating to arrest the73 49 named the death of a 21er  t   : the the the the the the the the the the the the the the',)\n",
      "label= <pad> Police are continuing to question a man over the murder of a pensioner in Bridgend county.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "pred= ('charity figures of patients waiting to treatment health treatment has d in 2011 past year years  figures show   the  charity charity charity charity charity charity charity charity charity charity charity charity charity',)\n",
      "label= <pad> The number of people waiting for mental health treatment has doubled in the past six years, figures have shown.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#测试\n",
    "def test():\n",
    "    model.eval()\n",
    "\n",
    "    #数据加载器\n",
    "    loader_test = torch.utils.data.DataLoader(\n",
    "        dataset=dataset['test'],\n",
    "        batch_size=4,\n",
    "        collate_fn=collate_fn,\n",
    "        shuffle=True,\n",
    "        drop_last=True,\n",
    "    )\n",
    "\n",
    "    for i, data in enumerate(loader_test):\n",
    "        break\n",
    "\n",
    "    #计算\n",
    "    with torch.no_grad():\n",
    "        out = model(**data)\n",
    "\n",
    "    for i in range(4):\n",
    "        input_ids = tokenizer.decode(data['input_ids'][i])\n",
    "        pred = tokenizer.decode(out['logits'].argmax(dim=2)[i]),\n",
    "        label = tokenizer.decode(data['decoder_input_ids'][i])\n",
    "\n",
    "        #print('input_ids=', input_ids)\n",
    "        print('pred=', pred)\n",
    "        print('label=', label)\n",
    "        print()\n",
    "\n",
    "\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ae2ab50c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/cpu/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  FutureWarning,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 3.93034291267395 1.9992e-05\n",
      "pred= ('the one  is the  is  ed as year  part of the breeding  breedingover the  the the the the the the the the the the the the the the the the',)\n",
      "label= <pad> Every single creature at Chester Zoo is being counted this week as part of its annual stock take.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "50 3.4090747833251953 1.9592e-05\n",
      "pred= ('whale was the first moment whenspotted whalehumpback whale breachingd off off the sea  off from thea whale  of  </s>  whale whale whale whale whale whale whale whale whale whale whale',)\n",
      "label= <pad> This is the spectacular moment a humpback whale breached high into the air just metres from a boat full of tourists.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "100 2.6107447147369385 1.9192000000000002e-05\n",
      "pred= ('Bennett judgeminhian judge, wasallegedlyaped a girlyear-old girl -d childs child, hisa suspected was been jailed for life yearsand-a-half years.</s>',)\n",
      "label= <pad> A West Lothian man who raped a 12-year-old and fathered a child with a teenager has been jailed for four-and-a-half years.\n",
      "\n",
      "150 3.2007808685302734 1.8792000000000002e-05\n",
      "pred= (\"Gabriel former scientist has been to trial for the case' for with ing  abetting  Pope.s wife butler, thea documentsl information from</s>: Gabriel    <extra_id_0> Cl   \",)\n",
      "label= <pad> A computer technician has gone on trial in the Vatican City charged with aiding and abetting the Pope's former butler in stealing papal documents.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "200 2.990238666534424 1.8392e-05\n",
      "pred= ('The I Economicnetary Fund (s  report Bank Forum ( the  is will will   for the. investorsa the pound.</s>:  the I the The The I I The The I The The The The I',)\n",
      "label= <pad> The International Monetary Fund's latest World Economic Outlook says the EU referendum campaign has created uncertainty for investors and weakened the pound.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "250 3.2237708568573 1.7992e-05\n",
      "pred= ('Ru turkey of the turkey Matthews firm firm will said was soldowed £, the was sold in have been  by will not be paid.</s></s> Ru the Ru The A Ru The The The The The',)\n",
      "label= <pad> Former suppliers to the Bernard Matthews turkey business, who were owed money before it was sold, have been told they will not be paid.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "300 2.6763546466827393 1.7592000000000004e-05\n",
      "pred= ('Sen newyuz satellite is been  the-er,  aim satellite Sen the world space.s orbit space-yearillion-dollaro--Observation satellite.</s>  the The Europe  The Europe European',)\n",
      "label= <pad> A Soyuz rocket has launched from French Guiana with the first satellite in the European Union's new multi-billion-euro Earth-observation programme.</s><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "350 2.7807886600494385 1.7192e-05\n",
      "pred= ('The UK\\' be \" ona \"big deal\" in the \",. the is to leave the EU Union   minister David Govon said said.</s></s> Cameron Cameron Cameron British Cameron Cameron ',)\n",
      "label= <pad> The UK will be taking a \"big gamble\" with its security if it votes to leave the European Union, defence secretary Michael Fallon has claimed.</s><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "400 3.089730739593506 1.6792e-05\n",
      "pred= (\"Police say the Indian capital of Bangalore say arrested   of a van   from s and have  of takingstealing money9 million rupees.13m) a£25,000'</s> the ban2,000 rupee notes,</s>\",)\n",
      "label= <pad> Police in the Indian city of Bangalore have arrested the driver of a van carrying cash for ATMs who is accused of stealing 9.2m rupees ($134,000; <unk>£107,000) in new 2,000 rupee notes.\n",
      "\n",
      "450 3.2825069427490234 1.6392e-05\n",
      "pred= ('They bes\"\"\" are thees are track their  over. be the s to the the insects of being danger. withor University has say.</s>',)\n",
      "label= <pad> Tiny \"backpacks\" for bees to track their flight paths could help provide clues about why the species is in decline, Bangor University experts say.\n",
      "\n",
      "500 2.939751625061035 1.5992000000000002e-05\n",
      "pred= ('Iraq the March, 6,000 Iraqi forces haves in operations on Tikrit, thea major inkm northandkm) from of Baghdad. is been occupied by IS State.S) forces the 2014.</s>',)\n",
      "label= <pad> On 1 March about 27,000 Iraqi troops commenced their attack on Tikrit, a city 150km (93 miles) north of Baghdad that has been occupied by Islamic State (IS) since June 2014.\n",
      "\n",
      "550 2.678152322769165 1.5592e-05\n",
      "pred= ('A least two people have Sind Sind have died in ings drinking inly inbre alcohol in police say.</s></s> police Police  Police   Police A   A',)\n",
      "label= <pad> At least 24 people in southern Pakistan have died from poisoning after drinking illegally-made alcohol, police say.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "600 2.8195526599884033 1.5192000000000003e-05\n",
      "pred= ('The band of Bo Bowie, been  by a tribute to in thea memorial- in</s> of  The The A The The The The The The The The The The The The The The A The The The A The The',)\n",
      "label= <pad> The life of David Bowie has been celebrated with a tribute concert at a London chapel.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "650 2.729532480239868 1.4792000000000002e-05\n",
      "pred= (\"The film has beenreportedly been added to thelern' a film about Sa author of theer in the Rye, which  was in New US.  audiences.</s></s> The\",)\n",
      "label= <pad> New material has reportedly been added to Salinger, a documentary about the author of Catcher in the Rye, after it opened in the US to mixed reviews.</s><pad>\n",
      "\n",
      "700 2.656344175338745 1.4392000000000002e-05\n",
      "pred= (\"The Niven'C Royalaire has the Gold Gold with Royal City Sunday.</s></s> The The The A The The The The The The The The The The The The The The The The The\",)\n",
      "label= <pad> Peter Niven-owned Clever Cookie won the Yorkshire Cup at York on Friday.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "750 3.025481700897217 1.3992000000000001e-05\n",
      "pred= ('Thelanda University the the India, is the Indianruinsdent university for Indian education in ago it was no major who the., Cambridge.s first university. indgna.</s></s> The The',)\n",
      "label= <pad> Nalanda, in northern India, was an eminent centre of higher education long before there were any students at Oxford, Cambridge or Europe's oldest university, Bologna.</s><pad><pad>\n",
      "\n",
      "800 3.4474802017211914 1.3592000000000001e-05\n",
      "pred= ('A of the new planets that-howcasingd out the orbit systems is been ped in planet of thethe likely-like\"s\" aer say said.</s>',)\n",
      "label= <pad> One of eight new planets spied in distant solar systems has usurped the title of \"most Earth-like alien world\", astronomers have said.\n",
      "\n",
      "850 3.0057315826416016 1.3192e-05\n",
      "pred= (\"TheUP'man Ni Lock Lockhart has spoken that was a losscarriage in</s>: The The The The The A The The The The The The The The The The The\",)\n",
      "label= <pad> DUP assembly member Carla Lockhart has revealed she suffered a miscarriage.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "900 2.6306569576263428 1.2792e-05\n",
      "pred= ('Ever Cup sidesrs Everovil were a  with theton on.  the points behind at Ever Premierbeatenlegation zone.</s></s></s> Ever Ever Ever Ever Ever Ever',)\n",
      "label= <pad> League Two strugglers Yeovil secured a draw at Luton Town to move four points clear of the relegation zone.</s><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "950 2.9503605365753174 1.2392000000000003e-05\n",
      "pred= (' match Cup tiefinalfinals of Middle and Manchester United will Lincoln will Tottenhamwall will be live on on  Sports Sports</s></s> Chelsea   Premier   Chelsea  Premier Premier Premier Premier Premier',)\n",
      "label= <pad> The FA Cup quarter-finals between Chelsea and Manchester United and Tottenham and Millwall will be broadcast live on BBC One.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1000 3.4620726108551025 1.1992000000000001e-05\n",
      "pred= ('A pellet killed been found in the back of  airgun to,</s>: A Poll A A A A A Poll A A A A A A A A A A A A A A A A A A A A A A Poll A A A A A A A',)\n",
      "label= <pad> A cat has been shot in the neck with an air rifle in Liverpool.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1050 2.843055009841919 1.1592000000000002e-05\n",
      "pred= ('Bristol West has  the own to  Bristol railtransie transit buses scheme route to the.</s></s> Bristol The The The The The The The The The The The  The The The The The The The The The The The A The The The The',)\n",
      "label= <pad> The government has given its approval to the planned \"rapid transit\" bus route through Bristol.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1100 3.3547728061676025 1.1192e-05\n",
      "pred= ('The- Ireland- associations have secured a £ fund50m- Choice  Union Fund of</s></s> the The The The The The The The The The The The The A',)\n",
      "label= <pad> Two Northern Ireland housing bodies have secured a combined £280m from the European Investment Bank.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1150 3.3273966312408447 1.0792000000000001e-05\n",
      "pred= ('Greece PrimeChancellor Angela Merkel has said shea deal between needed in Greece EU-al talks Greece.  bailout deal.</s></s> Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece Greece',)\n",
      "label= <pad> German Chancellor Angela Merkel has said a compromise is possible in the stand-off with Greece over its bailout terms.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1200 2.9355387687683105 1.0392e-05\n",
      "pred= (\"Celtic haves captain Armstrong has  to see  debut in the's Scottish Cup semi with Celticers.  out on  season's final-final.</s>. A Celtic Celtic A Celtic Celtic\",)\n",
      "label= <pad> Celtic's Stuart Armstrong is eager to make his mark in Sunday's Scottish Cup meeting with Rangers after missing out on last year's semi-final.</s><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1250 3.093193531036377 9.992e-06\n",
      "pred= ('South  of Lesotho has  up of of thelands, the workers the world of be  by by thebacks  the,..</s></s> The The South South The South',)\n",
      "label= <pad> The Kingdom of Lesotho is made up mostly of highlands where many of the villages can be reached only on horseback, by foot or light aircraft.</s><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1300 2.929048776626587 9.592e-06\n",
      "pred= (\"Ospreys' Jordanone Jonesne will be be ruled to to the after he iss from injurya shoulder injury. according  coach Tim Tandy.</s></s>\",)\n",
      "label= <pad> Ospreys lock Alun Wyn Jones will not be rushed back into action as he recovers from a shoulder injury, says head coach Steve Tandy.</s>\n",
      "\n",
      "1350 2.468980312347412 9.192000000000001e-06\n",
      "pred= ('s are for Veronica 19-year-old woman have arrested a woman who connection with  disappearance.</s></s>  police       ',)\n",
      "label= <pad> Detectives looking for a 15-year-old girl have arrested a man in connection with her disappearance.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1400 3.4944376945495605 8.792e-06\n",
      "pred= (\"A  driver died in into the over in  iclorry driver  way1 was the  hours of the' seriouslya woman policelor.year-old woman. police say said.</s>\",)\n",
      "label= <pad> The pedestrian who was run over and killed by an Iceland lorry on the M4 in the early hours of Tuesday was a local 66-year-old woman, police have said.\n",
      "\n",
      "1450 3.020413875579834 8.392e-06\n",
      "pred= ('The review of the response are with be with terrorist terrorists are the are UK aty\"ing\" the to be withdrawn from thes.  to a reportaked document.</s>',)\n",
      "label= <pad> A third of emergency vehicles equipped to deal with major contaminations in England, including \"dirty bombs\", are to be withdrawn on 31 December, according to a leaked document.\n",
      "\n",
      "1500 2.95689058303833 7.992e-06\n",
      "pred= ('A number of deaths drivers in in injured in by the road in Wales is droppedrisen by 3a since 2004 last year.</s></s> the Wales The Wales A A Welsh The The A Wales A',)\n",
      "label= <pad> The number of young people killed or seriously injured on the roads in Wales has risen by 8% in the past year.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1550 3.032794952392578 7.592e-06\n",
      "pred= ('A Walker,s horsey Antarctic wons-aticd the in  the finalaven intakes in Newmarket.</s></s>  The The The The The A The A The The A A A A The A A A The The',)\n",
      "label= <pad> Ed Walker's Stormy Antarctic beat highly fancied Foundation to capture the Craven Stakes at Newmarket.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1600 2.752349853515625 7.192e-06\n",
      "pred= (\"Conservative Conservative has passed  by' issue spending cuts  years MPs havebd the views' Brexit past of</s>  Conservative Conservative  Conservative Conservative Conservative Conservative Conservative Conservative Conservative Conservative Conservative Conservative The Conservative Conservative\",)\n",
      "label= <pad> The government has been defeated in Parliament on the EU budget after 53 Conservative MPs defied their party over the issue.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1650 2.95678448677063 6.792000000000001e-06\n",
      "pred= ('A latests have criticizedacted to  to criticism about the recent attacks of attackss in have \"y of the into the conclusions. s the violencetru violence of</s>',)\n",
      "label= <pad> The German media have reacted with shock and concern to the recent spate of killings but are wary of jumping to wider conclusions and warn against giving in to fear.\n",
      "\n",
      "1700 3.232518196105957 6.392000000000001e-06\n",
      "pred= ('The is be been  to  health inspector to investigate the stop  death of a man-year-old man. Arnadaoe,. WalesHI state has been.</s></s>   A The A The A A  A The A The A A A The A A A The A The A A A The A A A A A A The A A',)\n",
      "label= <pad> It would have been difficult for mental health services to predict or prevent the killing of a 22-year-old woman in Caerphilly county, a review has said.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1750 2.645395278930664 5.992e-06\n",
      "pred= ('A couple of a former who of murdering a former rival has said  had \"aever  towards to \"a has askednever regret\" of was a killed  affair with</s>',)\n",
      "label= <pad> The partner of a woman accused of murdering a love rival has said she was \"never violent\" and he had \"no idea\" she was allegedly having an affair.\n",
      "\n",
      "1800 3.1919140815734863 5.592000000000001e-06\n",
      "pred= ('The council committee has agreed  to a council development in  home for thea site  course. Greentna.</s></s>. The The The A The The The The',)\n",
      "label= <pad> A planning committee has recommended refusal for a residential development and nursing home on a former golf course in Gretna.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1850 2.9238357543945312 5.1920000000000004e-06\n",
      "pred= ('Cardiff Slade has the club will Cardiff manager manager is be thea \" that  new to move the to the top League.</s></s> Cardiff Cardiff Cardiff Cardiff Cardiff Cardiff Cardiff Cardiff Cardiff Cardiff Cardiff Cardiff Cardiff Cardiff Cardiff',)\n",
      "label= <pad> Russell Slade believes his successor as Cardiff City boss will inherit a squad with the potential to win promotion to the Premier League.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "1900 2.913532257080078 4.792000000000001e-06\n",
      "pred= (\"The Supreme Court has ruled that favour of the estate familyranged couple of a wealthyn who companyheiriranton who thea court of' they be been for her wealthy couplesvorcing spouse.</s>\",)\n",
      "label= <pad> The Supreme Court has ruled in favour of the English estranged wife of a Nigerian oil tycoon in a case lawyers say could have implications for some wealthy divorcing couples.\n",
      "\n",
      "1950 2.7509372234344482 4.3920000000000005e-06\n",
      "pred= ('A former police driver has was  murders in been toa hearing Court hearing. the casegang security of comply  forces criminala.  policess.</s></s> Mr A A A A A A',)\n",
      "label= <pad> A Catholic taxi driver who survived two murder attempts has begun a High Court challenge over the alleged failure to investigate security force collusion with loyalist killers.</s><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2000 2.6884877681732178 3.992e-06\n",
      "pred= ('The council committee received  ised by failing it has education with extra needs needs in</s>: The The The The The The The The The A The The The The The The The The The',)\n",
      "label= <pad> A council has been repeatedly criticised for how it handles children with special educational needs.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "2050 3.6314761638641357 3.5920000000000005e-06\n",
      "pred= ('A  gngynoedd onddy ytholiad  yngw   e ahi an ai y ngi  yth yiw yynd r  ywdd  d   r y yaa yynos  o  angeh yi  yn bby<unk>ai</s></s></s></s></s></s></s></s></s></s></s></s></s> </s>',)\n",
      "label= <pad> Ers y cyhoeddiad bod etholiad ar y gweill, mae rhai yn dadlau bod Cymru ar ei cholled, gyda'r wasg Brydeinig weithiau'n anghofio cyfeirio at y ffaith nad yw pob polisi neu addewid yn berthnasol yr\n",
      "\n",
      "2100 2.844595193862915 3.192e-06\n",
      "pred= ('The sea scientistfoodl\"\" is being held by a \" death of  on   of thea er\\'.</s></s> the The The The The The The The The The The The',)\n",
      "label= <pad> A national \"gull summit\" is being proposed following a recent spate of attacks, including one on a pensioner in Cornwall.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "2150 3.2856357097625732 2.792e-06\n",
      "pred= ('A hass have been s a  fireblaze at  ston terminal plant in Northern Yorkshirelanlanshire.</s></s> the The The A A A The A A A A',)\n",
      "label= <pad> Fire crews have been tackling a large blaze at the Hunterston coal port in North Ayrshire.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "2200 3.05957293510437 2.392e-06\n",
      "pred= (\"A health issues in Hu and be £7 million in the funds to help children people'</s></s>  A A The A A A A A A A A A A A A A A A A A A A A A A A A\",)\n",
      "label= <pad> Mental health projects in England will receive £55 million from lottery funds to support young people.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "2250 2.5619454383850098 1.992e-06\n",
      "pred= ('A Fife have up to third place the Premier and with a draw1-0 win over Peterion Rovers at Westdction held were to1-0 at Ham to thenraer.</s></s> A',)\n",
      "label= <pad> East Fife moved up to third in Scottish League One with a 2-0 win over Albion Rovers while Airdireonians lost 2-1 at home to Stranraer.</s><pad>\n",
      "\n",
      "2300 2.8082973957061768 1.5920000000000002e-06\n",
      "pred= ('The for been madeveiled to help the of thea new brewery in thechester Arts  arts centre.</s></s> The The The The The The The The The The The The The The The The The The The The The The The',)\n",
      "label= <pad> Plans have been unveiled to turn part of a former brewery in Dorchester into an arts centre.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "2350 2.8484041690826416 1.1920000000000002e-06\n",
      "pred= ('A former of a former Speaker Parliaments woman has   a  speaker, but alone a woman. who</s></s> The She The The A A A The A A A A She A',)\n",
      "label= <pad> The spouse of a House of Commons Speaker is not normally a public figure, let alone a minor celebrity.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "2400 3.0417096614837646 7.920000000000001e-07\n",
      "pred= ('A people the most- in use as the and the, the, their to and, others.  of thea mosque-experiencetrious Muslim Muslim community of Islamfi..  is become own in theegal.</s>',)\n",
      "label= <pad> Many of the street vendors commonly seen in Italy, France and Spain selling sunglasses, bags and souvenirs are members of a highly industrious, entrepreneurial branch of Sufi Islam, which has its roots in Senegal.\n",
      "\n",
      "2450 3.037229537963867 3.92e-07\n",
      "pred= ('A Ministers Cameron has allsising\" by  government Company charity.ddy.hn... his theridden the about over  has said told.</s></s> The The The The The A A The A',)\n",
      "label= <pad> Prime Minister David Cameron was \"mesmerised\" by the Kids Company boss Camila Batmanghelidjh and over-ruled concerns raised, it has been claimed.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from transformers.trainer_pt_utils import get_parameter_names\n",
    "from transformers import AdamW\n",
    "from transformers.optimization import get_scheduler\n",
    "\n",
    "\n",
    "#训练\n",
    "def train():\n",
    "    \"\"\"Fine-tune the notebook-level ``model`` for one pass over ``loader``.\n",
    "\n",
    "    Steps performed:\n",
    "      * build two AdamW parameter groups: weight decay 1e-2 for the names\n",
    "        returned by ``get_parameter_names(model, [torch.nn.LayerNorm])``\n",
    "        that do not contain 'bias', and 0.0 for every other parameter;\n",
    "      * attach a 'linear' schedule with no warm-up spanning ``len(loader)``\n",
    "        steps;\n",
    "      * run one optimizer step per batch with gradient-norm clipping at 1.0;\n",
    "      * every 50 steps print the step index, loss, current learning rate and\n",
    "        an argmax decode of the first sample in the batch;\n",
    "      * after the last batch, save the whole model object to\n",
    "        'models/4.长求总.model'.\n",
    "\n",
    "    Relies on names defined elsewhere in the notebook: ``model``, ``loader``,\n",
    "    ``tokenizer`` and ``torch``. Returns nothing.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    save_path = 'models/4.长求总.model'\n",
    "    # The checkpoint is only written once training has finished, so make sure\n",
    "    # the target directory exists now instead of failing after the whole run.\n",
    "    os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "\n",
    "    # Names that take part in weight decay. A set keeps the membership test\n",
    "    # in the loop below O(1) per parameter.\n",
    "    decay_names = {\n",
    "        name\n",
    "        for name in get_parameter_names(model, [torch.nn.LayerNorm])\n",
    "        if 'bias' not in name\n",
    "    }\n",
    "\n",
    "    # Split every parameter into a decayed and a non-decayed group.\n",
    "    decay_params, no_decay_params = [], []\n",
    "    for name, param in model.named_parameters():\n",
    "        if name in decay_names:\n",
    "            decay_params.append(param)\n",
    "        else:\n",
    "            no_decay_params.append(param)\n",
    "    param_groups = [\n",
    "        {'params': decay_params, 'weight_decay': 1e-2},\n",
    "        {'params': no_decay_params, 'weight_decay': 0.0},\n",
    "    ]\n",
    "\n",
    "    # Optimizer\n",
    "    optimizer = AdamW(param_groups, betas=(0.9, 0.999), eps=1e-8, lr=2e-5)\n",
    "\n",
    "    # Learning-rate scheduler: one scheduler step per batch of the loader.\n",
    "    scheduler = get_scheduler(name='linear',\n",
    "                              num_warmup_steps=0,\n",
    "                              num_training_steps=len(loader),\n",
    "                              optimizer=optimizer)\n",
    "\n",
    "    model.train()\n",
    "    for i, data in enumerate(loader):\n",
    "        out = model(**data)\n",
    "        loss = out['loss']\n",
    "\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "\n",
    "        # Both groups together hold every entry of model.named_parameters(),\n",
    "        # so this clears all gradients; no extra model.zero_grad() is needed.\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        if i % 50 == 0:\n",
    "            # Keep this a plain string: a trailing comma here would turn\n",
    "            # `pred` into a 1-tuple and print it as ('...',).\n",
    "            pred = tokenizer.decode(out['logits'].argmax(dim=2)[0])\n",
    "\n",
    "            # NOTE(review): this decodes decoder_input_ids, not data['labels'];\n",
    "            # presumably the shifted target sequence - confirm against the\n",
    "            # collate function.\n",
    "            label = tokenizer.decode(data['decoder_input_ids'][0])\n",
    "\n",
    "            lr = optimizer.param_groups[0]['lr']\n",
    "            print(i, loss.item(), lr)\n",
    "            print('pred=', pred)\n",
    "            print('label=', label)\n",
    "            print()\n",
    "\n",
    "    # Save once, after the final step, rather than pickling the full model on\n",
    "    # every iteration.\n",
    "    torch.save(model, save_path)\n",
    "\n",
    "\n",
    "# Run the training pass defined above; train() also writes the checkpoint\n",
    "# to 'models/4.长求总.model'.\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2167896e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pred= ('Swansea City hass plans has pleased that will be made in the with  the stadium Stadium.</s></s> The The The The The The Swan Swan Swan Swan Swan Swan',)\n",
      "label= <pad> Swansea council's leader is optimistic progress can be made during talks about expanding the Liberty Stadium.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "pred= ('A President candidate candidate Pault Romney is looking to be  own presidentpresidential candidate mate,.-    of</s></s>   Republican Republican Republican Republican',)\n",
      "label= <pad> US Republican presidential candidate Mitt Romney is expected to name his vice-presidential running mate soon - possibly within days.</s><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "pred= ('A woman doctor haswhose car was found in  sea in  s andrelated injuries\" has family has said.</s></s> Tor A A A A A A A A A',)\n",
      "label= <pad> A junior doctor whose body was found in the sea suffered from \"work-related anxiety\", her sister has said.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>\n",
      "\n",
      "pred= ('Notshire haves formerrelegation to Division Premier of iss Division- has \"aboutmbararrassment\" says to BBC John the. Newell.</s>',)\n",
      "label= <pad> Nottinghamshire's relegation from the County Championship's top flight is \"embarrassing\", according to director of cricket Mick Newell.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Reload the checkpoint that train() saved to this path and rebind the\n",
    "# notebook-level `model` to it.\n",
    "# NOTE(review): torch.load unpickles a whole model object here; only load\n",
    "# files from a trusted source.\n",
    "model = torch.load('models/4.长求总.model')\n",
    "# test() is defined elsewhere in the notebook; the recorded output shows it\n",
    "# printing pred/label pairs.\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
